{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SDGym Benchmark" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 | Train Loss -0.649\n", "Epoch 20 | Train Loss -0.688\n", "Epoch 30 | Train Loss -0.700\n", "Epoch 40 | Train Loss -0.707\n", "Epoch 50 | Train Loss -0.709\n", "Epoch 60 | Train Loss -0.731\n", "Epoch 70 | Train Loss -0.727\n", "Epoch 80 | Train Loss -0.728\n", "Epoch 90 | Train Loss -0.720\n", "Epoch 100 | Train Loss -0.731\n", "Epoch 10 | Train Loss -0.622\n", "Epoch 20 | Train Loss -0.677\n", "Epoch 30 | Train Loss -0.697\n", "Epoch 40 | Train Loss -0.715\n", "Epoch 50 | Train Loss -0.718\n", "Epoch 60 | Train Loss -0.715\n", "Epoch 70 | Train Loss -0.719\n", "Epoch 80 | Train Loss -0.711\n", "Epoch 90 | Train Loss -0.714\n", "Epoch 100 | Train Loss -0.717\n", "Epoch 10 | Train Loss -0.647\n", "Epoch 20 | Train Loss -0.687\n", "Epoch 30 | Train Loss -0.701\n", "Epoch 40 | Train Loss -0.711\n", "Epoch 50 | Train Loss -0.707\n", "Epoch 60 | Train Loss -0.728\n", "Epoch 70 | Train Loss -0.720\n", "Epoch 80 | Train Loss -0.732\n", "Epoch 90 | Train Loss -0.726\n", "Epoch 100 | Train Loss -0.724\n", "Epoch 10 | Train Loss 0.058\n", "Epoch 20 | Train Loss 0.023\n", "Epoch 30 | Train Loss -0.052\n", "Epoch 40 | Train Loss -0.045\n", "Epoch 50 | Train Loss -0.085\n", "Epoch 60 | Train Loss -0.110\n", "Epoch 70 | Train Loss -0.086\n", "Epoch 80 | Train Loss -0.104\n", "Epoch 90 | Train Loss -0.127\n", "Epoch 100 | Train Loss -0.134\n", "Epoch 10 | Train Loss 0.043\n", "Epoch 20 | Train Loss -0.001\n", "Epoch 30 | Train Loss -0.035\n", "Epoch 40 | Train Loss -0.044\n", "Epoch 50 | Train Loss -0.036\n", "Epoch 60 | Train Loss -0.051\n", "Epoch 70 | Train Loss -0.143\n", "Epoch 80 | Train Loss -0.118\n", "Epoch 90 | Train Loss -0.130\n", "Epoch 100 | Train Loss -0.113\n", "Epoch 10 | Train Loss 0.037\n", "Epoch 20 | Train Loss -0.038\n", "Epoch 30 | Train Loss -0.093\n", "Epoch 40 | Train Loss -0.075\n", "Epoch 50 | Train Loss -0.152\n", "Epoch 60 | Train Loss -0.146\n", "Epoch 70 | Train Loss -0.146\n", "Epoch 80 | Train Loss -0.173\n", "Epoch 90 | Train Loss -0.146\n", "Epoch 100 | Train Loss -0.137\n", "Epoch 10 | Train Loss -0.077\n", "Epoch 20 | Train Loss -0.087\n", "Epoch 30 | Train Loss -0.121\n", "Epoch 40 | Train Loss -0.153\n", "Epoch 50 | Train Loss -0.204\n", "Epoch 60 | Train Loss -0.207\n", "Epoch 70 | Train Loss -0.245\n", "Epoch 80 | Train Loss -0.240\n", "Epoch 90 | Train Loss -0.240\n", "Epoch 100 | Train Loss -0.224\n", "Epoch 10 | Train Loss -0.111\n", "Epoch 20 | Train Loss -0.133\n", "Epoch 30 | Train Loss -0.143\n", "Epoch 40 | Train Loss -0.200\n", "Epoch 50 | Train Loss -0.214\n", "Epoch 60 | Train Loss -0.187\n", "Epoch 70 | Train Loss -0.205\n", "Epoch 80 | Train Loss -0.217\n", "Epoch 90 | Train Loss -0.222\n", "Epoch 100 | Train Loss -0.189\n", "Epoch 10 | Train Loss -0.106\n", "Epoch 20 | Train Loss -0.138\n", "Epoch 30 | Train Loss -0.126\n", "Epoch 40 | Train Loss -0.199\n", "Epoch 50 | Train Loss -0.214\n", "Epoch 60 | Train Loss -0.203\n", "Epoch 70 | Train Loss -0.222\n", "Epoch 80 | Train Loss -0.161\n", "Epoch 90 | Train Loss -0.223\n", "Epoch 100 | Train Loss -0.240\n", "Epoch 10 | Train Loss -0.083\n", "Epoch 20 | Train Loss -0.094\n", "Epoch 30 | Train Loss -0.105\n", "Epoch 40 | Train Loss -0.117\n", "Epoch 50 | Train Loss -0.126\n", "Epoch 60 | Train Loss -0.129\n", "Epoch 70 | Train Loss -0.118\n", "Epoch 80 | Train Loss -0.143\n", "Epoch 90 | Train Loss -0.143\n", "Epoch 100 | Train Loss -0.145\n", "Epoch 10 | Train Loss -0.070\n", "Epoch 20 | Train Loss -0.106\n", "Epoch 30 | Train Loss -0.100\n", "Epoch 40 | Train Loss -0.122\n", "Epoch 50 | Train Loss -0.125\n", "Epoch 60 | Train Loss -0.135\n", "Epoch 70 | Train Loss -0.138\n", "Epoch 80 | Train Loss -0.124\n", "Epoch 90 | Train Loss -0.142\n", "Epoch 100 | Train Loss -0.144\n", "Epoch 10 | Train Loss -0.077\n", "Epoch 20 | Train Loss -0.100\n", "Epoch 30 | Train Loss -0.116\n", "Epoch 40 | Train Loss -0.128\n", "Epoch 50 | Train Loss -0.121\n", "Epoch 60 | Train Loss -0.146\n", "Epoch 70 | Train Loss -0.147\n", "Epoch 80 | Train Loss -0.150\n", "Epoch 90 | Train Loss -0.147\n", "Epoch 100 | Train Loss -0.148\n", "Epoch 10 | Train Loss 0.017\n", "Epoch 20 | Train Loss -0.029\n", "Epoch 30 | Train Loss -0.040\n", "Epoch 40 | Train Loss -0.033\n", "Epoch 50 | Train Loss -0.072\n", "Epoch 60 | Train Loss -0.064\n", "Epoch 70 | Train Loss -0.084\n", "Epoch 80 | Train Loss -0.067\n", "Epoch 90 | Train Loss -0.082\n", "Epoch 100 | Train Loss -0.076\n", "Epoch 10 | Train Loss -0.003\n", "Epoch 20 | Train Loss -0.022\n", "Epoch 30 | Train Loss -0.053\n", "Epoch 40 | Train Loss -0.076\n", "Epoch 50 | Train Loss -0.049\n", "Epoch 60 | Train Loss -0.090\n", "Epoch 70 | Train Loss -0.095\n", "Epoch 80 | Train Loss -0.105\n", "Epoch 90 | Train Loss -0.107\n", "Epoch 100 | Train Loss -0.120\n", "Epoch 10 | Train Loss 0.011\n", "Epoch 20 | Train Loss -0.022\n", "Epoch 30 | Train Loss -0.049\n", "Epoch 40 | Train Loss -0.055\n", "Epoch 50 | Train Loss -0.059\n", "Epoch 60 | Train Loss -0.067\n", "Epoch 70 | Train Loss -0.084\n", "Epoch 80 | Train Loss -0.095\n", "Epoch 90 | Train Loss -0.098\n", "Epoch 100 | Train Loss -0.090\n", "Epoch 10 | Train Loss -0.102\n", "Epoch 20 | Train Loss -0.129\n", "Epoch 30 | Train Loss -0.138\n", "Epoch 40 | Train Loss -0.148\n", "Epoch 50 | Train Loss -0.171\n", "Epoch 60 | Train Loss -0.144\n", "Epoch 70 | Train Loss -0.201\n", "Epoch 80 | Train Loss -0.182\n", "Epoch 90 | Train Loss -0.216\n", "Epoch 100 | Train Loss -0.252\n", "Epoch 10 | Train Loss -0.057\n", "Epoch 20 | Train Loss -0.089\n", "Epoch 30 | Train Loss -0.115\n", "Epoch 40 | Train Loss -0.120\n", "Epoch 50 | Train Loss -0.140\n", "Epoch 60 | Train Loss -0.174\n", "Epoch 70 | Train Loss -0.151\n", "Epoch 80 | Train Loss -0.199\n", "Epoch 90 | Train Loss -0.199\n", "Epoch 100 | Train Loss -0.201\n", "Epoch 10 | Train Loss -0.075\n", "Epoch 20 | Train Loss -0.092\n", "Epoch 30 | Train Loss -0.110\n", "Epoch 40 | Train Loss -0.118\n", "Epoch 50 | Train Loss -0.166\n", "Epoch 60 | Train Loss -0.163\n", "Epoch 70 | Train Loss -0.140\n", "Epoch 80 | Train Loss -0.159\n", "Epoch 90 | Train Loss -0.161\n", "Epoch 100 | Train Loss -0.158\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "import sdgym\n", "from echoflow import EchoFlow\n", "\n", "def EchoFlowSynthesizer(real_data, categorical_columns, ordinal_columns):\n", " df = pd.DataFrame(real_data)\n", " for i in categorical_columns+ordinal_columns:\n", " df[i] = df[i].astype(int).astype(str)\n", " \n", " model = EchoFlow(nb_epochs=100)\n", " model.fit(df)\n", " new_df = model.sample(num_samples=len(df))\n", " \n", " for i in categorical_columns+ordinal_columns:\n", " new_df[i] = new_df[i].astype(int)\n", " arr = new_df.values\n", " \n", " return arr\n", "\n", "def EchoFlowSynthesizeKDE(real_data, categorical_columns, ordinal_columns):\n", " df = pd.DataFrame(real_data)\n", " for i in categorical_columns+ordinal_columns:\n", " df[i] = df[i].astype(int).astype(str)\n", " \n", " model = EchoFlow(nb_epochs=100, use_kde=True)\n", " model.fit(df)\n", " new_df = model.sample(num_samples=len(df))\n", " \n", " for i in categorical_columns+ordinal_columns:\n", " new_df[i] = new_df[i].astype(int)\n", " arr = new_df.values\n", " \n", " return arr\n", "\n", "scores = sdgym.run(synthesizers=[\n", " EchoFlowSynthesizer, \n", " EchoFlowSynthesizeKDE\n", "], datasets=['ring', 'grid', 'gridr'], iterations=3)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | grid/syn_likelihood | \n", "grid/test_likelihood | \n", "gridr/syn_likelihood | \n", "gridr/test_likelihood | \n", "ring/syn_likelihood | \n", "ring/test_likelihood | \n", "timestamp | \n", "
---|---|---|---|---|---|---|---|
CTGAN | \n", "-8.760635 | \n", "-5.062972 | \n", "-8.309750 | \n", "-5.048310 | \n", "-6.591324 | \n", "-2.665281 | \n", "2020-10-17 09:46:54.494331 | \n", "
EchoFlowSynthesizer | \n", "-6.712230 | \n", "-4.437056 | \n", "-6.496902 | \n", "-4.475942 | \n", "-1.932969 | \n", "-1.796832 | \n", "2020-12-30 23:10:22.816115 | \n", "
EchoFlowSynthesizeKDE | \n", "-5.402527 | \n", "-4.063265 | \n", "-5.531003 | \n", "-4.154107 | \n", "-2.277480 | \n", "-1.842371 | \n", "2020-12-30 23:10:22.816115 | \n", "